package asm.dispatch.test;

import org.ricks.asm.Handle;
import org.ricks.asm.Label;
import org.ricks.asm.MethodVisitor;
import org.ricks.asm.Opcodes;
import org.ricks.common.Tuple;
import org.ricks.ioc.App;
import org.ricks.ioc.AppContext;
import org.ricks.ioc.Bean;
import org.ricks.net.ActionMethod;
import org.ricks.ioc.bean.factory.AbstractBeanFactory;
import org.ricks.log.Logger;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

import static org.ricks.asm.Opcodes.*;

/**
 * @author chenwei
 * @Description:构建函数也是方法
 * @date 2022/11/2915:36
 */
public class DispatchMethodVisitor extends MethodVisitor {

    static List<Tuple> tupleList = initPairs();

    protected DispatchMethodVisitor(MethodVisitor mv) {
        super(ASM7,mv);
    }

    @Override
    public void visitCode() {

//        Label label = new Label();
//        mv.visitLabel(label);
//        mv.visitLineNumber(15, label);
//        mv.visitFieldInsn(GETSTATIC, "java/lang/System", "err", "Ljava/io/PrintStream;");
//        mv.visitLdcInsn("AMS \u521d\u59cb\u72b6\u6001\u3002\u3002\u3002\u3002\u3002\u3002\u3002\u3002\u3002\u3002");
//        mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);

        Label label0 = new Label();
        mv.visitLabel(label0);
        mv.visitLineNumber(23, label0);
        mv.visitVarInsn(ALOAD, 0);
        mv.visitMethodInsn(INVOKEVIRTUAL, "org/ricks/common/Context", "getCmd", "()Ljava/lang/Object;", false);
        mv.visitTypeInsn(CHECKCAST, "java/lang/Short");
        mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Short", "shortValue", "()S", false);

        mv.visitVarInsn(ISTORE, 1);
        Label label1 = new Label();
        mv.visitLabel(label1);
        mv.visitLineNumber(24, label1);
        mv.visitVarInsn(ILOAD, 1);

        tupleList = tupleList.stream().sorted(Comparator.comparing(tuple -> tuple.get(0))).collect(Collectors.toList()); //妈的，浪费老子那么多时间。switch 必须要从小到大排序不然生成的class字节码会进入 default 默认。fuck you
        int[] messageIds = new int[tupleList.size()];
        Label[] labels = new Label[tupleList.size()];
        for (int i = 0; i < tupleList.size(); i++) {
            labels[i] = new Label();
            messageIds[i] = Integer.valueOf(tupleList.get(i).get(0).toString());
        }
        Label defaultLabel  = new Label();
        Label gotoLabel  = new Label();
        mv.visitLookupSwitchInsn(defaultLabel, messageIds, labels);
        int currLineNum = 24;
        for (int i = 0; i < tupleList.size(); i++) {
            currLineNum += 1;
            createMethodFunction(mv,tupleList.get(i), labels[i],currLineNum,gotoLabel,i==0);
        }

        mv.visitLabel(defaultLabel);
        mv.visitLineNumber(currLineNum + 1, defaultLabel);
        mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
        mv.visitFieldInsn(GETSTATIC, "java/lang/System", "err", "Ljava/io/PrintStream;");
        mv.visitVarInsn(ALOAD, 0);
        mv.visitMethodInsn(INVOKEVIRTUAL, "org/ricks/common/Context", "getCmd", "()Ljava/lang/Object;", false);
        mv.visitInvokeDynamicInsn("makeConcatWithConstants", "(Ljava/lang/Object;)Ljava/lang/String;", new Handle(Opcodes.H_INVOKESTATIC, "java/lang/invoke/StringConcatFactory", "makeConcatWithConstants", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;", false), new Object[]{"\u0001 \u672a\u77e5\u6307\u4ee4"});
        mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
        mv.visitLabel(gotoLabel);
        mv.visitLineNumber(currLineNum + 5, gotoLabel);
        mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
        mv.visitInsn(RETURN);
        Label label14 = new Label();
        mv.visitLabel(label14);
        mv.visitLocalVariable("context", "Lorg/demon/common/Context;", "Lorg/demon/common/Context<Ljava/lang/Short;[Ljava/lang/Byte;>;", label0, label14, 0);
        mv.visitLocalVariable("cmd", "S", null, label1, label14, 1);
    }

    @Override
    public void visitMaxs(int maxStack, int maxLocals) {
        super.visitMaxs(maxStack+1, maxLocals);
    }

    private static void  createMethodFunction(MethodVisitor methodVisitor, Tuple tuple, Label label, int lineNum, Label gotoLabel,boolean isFirst) {
        String className = tuple.get(1).toString();
        String cName = className.replace(".","/");
        String objName = className.substring(className.lastIndexOf(".") + 1,className.length()).toLowerCase(Locale.ROOT);
        methodVisitor.visitLabel(label);
        methodVisitor.visitLineNumber(lineNum, label);
        if(isFirst) methodVisitor.visitFrame(Opcodes.F_APPEND, 1, new Object[]{Opcodes.INTEGER}, 0, null); else methodVisitor.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
        methodVisitor.visitFieldInsn(GETSTATIC, "org/ricks/dispatch/Dispatcher", objName, "L"+cName+";");
        methodVisitor.visitVarInsn(ALOAD, 0);
        methodVisitor.visitMethodInsn(INVOKEVIRTUAL, cName, tuple.get(2), "(Lorg/demon/common/Context;)V", false);

//        methodVisitor.visitFieldInsn(GETSTATIC, "java/lang/System", "err", "Ljava/io/PrintStream;");
//        methodVisitor.visitLdcInsn("\u8fdb\u5165\u539f\u751f\u6001\u3002\u3002\u3002\u3002\u3002\u3002\u3002\u3002");
//        methodVisitor.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);

        methodVisitor.visitJumpInsn(GOTO, gotoLabel);

    }

    private static List<Tuple> initPairs() {

        App context = AppContext.getApplication();
        List<Tuple> pairs = new ArrayList<>();
        AbstractBeanFactory beanFactory = context.getBeanFactory();
        if(beanFactory != null) {
            List<Object> beans = beanFactory.getBeansForAnnotation(Bean.class);
            if (!beans.isEmpty()) {
                for (Object obj : beans) {
                    try {
                        Method[] methods = obj.getClass().getMethods();
                        for (Method method : methods) {
                            ActionMethod messageHandler = method.getAnnotation(ActionMethod.class);
                            if (messageHandler == null) {
                                continue;
                            }
                            Class<?>[] parameterClazzes = method.getParameterTypes();
                            if (parameterClazzes.length != 1) {
                                throw new IllegalArgumentException("消息处理方法的参数不正确，参数必须是一个。");
                            }
                            Class<?> commandClass = parameterClazzes[0];
                            if (!org.ricks.common.Context.class.isAssignableFrom(commandClass)) {
                                continue;
                            }
                            pairs.add(new Tuple(messageHandler.messageId(),getClassName(obj).replace("/","."),method.getName()));
                        }
                    } catch (Exception e) {
                        Logger.error("", e);
                    }
                }
            }
        }
        return pairs;
    }

    private static String getClassName(Object obj) {
        String beanName = obj.getClass().getName();
        return beanName.replace(".","/");
    }
}
